import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter



# Prepare the CIFAR-10 test split; ToTensor() turns each sample into a tensor
# so whole batches can be handed to SummaryWriter.add_images.
# download=True fetches the archive on first run instead of raising a
# "dataset not found" error; when the files already exist under ./dataset
# torchvision only checks them and skips the download.
test_data = torchvision.datasets.CIFAR10(
    root="./dataset",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

# shuffle=True reshuffles at the start of every epoch, so the two epochs log
# different batch orderings (presumably the point of this demo — confirm).
# drop_last=True discards the final partial batch, so every logged grid holds
# exactly batch_size (64) images.
test_load = DataLoader(
    dataset=test_data,
    batch_size=64,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)

# Log every batch as an image grid under one tag per epoch ("Epoch 0",
# "Epoch 1"); the step index restarts at 0 for each epoch. Using the writer as
# a context manager guarantees the event file is flushed and closed even if
# loading or logging raises part-way through.
with SummaryWriter("dataloader") as writer:
    for epoch in range(2):
        for step, (imgs, _targets) in enumerate(test_load):
            writer.add_images(f"Epoch {epoch}", imgs, step)
